#!/usr/bin/env python3
"""
steer_grid.py — multi-mode grid runner (baseline outside sweep)

Modes swept:
  vec_base      : use vec::baseline_all_steps
  soft_prob     : sample one of top-3 edge vecs by probability per batch
  soft_argmax   : always pick highest-prob edge vec

Global baseline:
  Printed once as [grid-baseline], with NO steering.
"""

from __future__ import annotations
import argparse, os, json, csv, re, warnings
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
from tqdm.auto import tqdm

# ---- import helpers from steer_eval_soft.py ----
from steer_eval_soft import (
    _dist_init, _dist_is_enabled, _dist_rank, _dist_world, _only_rank0, _shard_list, _pick_device,
    load_preproc, invert_preproc_step, make_schedule, LLMRunner,
    _load_gpqa_diamond_items, _load_hf_dataset_items, set_global_determinism,
    _grade,
)

# ----------------- utils -----------------
def _sanitize(s: str) -> str:
    return re.sub(r'[^A-Za-z0-9_.:-]+', '-', str(s))

def _combo_id(alpha: float, schedule_name: str, layer_list: list[int]) -> str:
    """Build the ``a<alpha>__<schedule>__L<layers>`` identifier for one sweep configuration.

    ``alpha`` is rendered with the ``g`` format, the schedule name is passed
    through ``_sanitize`` and the layers are joined with commas.
    """
    layers_csv = ",".join(str(layer) for layer in layer_list)
    safe_schedule = _sanitize(schedule_name)
    return f"a{alpha:g}__{safe_schedule}__L{layers_csv}"

def _parse_alphas(s: str) -> List[float]:
    return [float(x) for x in s.split(",") if x.strip()]

def _parse_schedules(s: str) -> List[str]:
    return [x.strip() for x in s.split(",") if x.strip()]

def _parse_layer_sets(s: str) -> List[List[int]]:
    out = []
    for chunk in s.split("|"):
        chunk = chunk.strip()
        if not chunk:
            continue
        if "-" in chunk:
            a, b = chunk.split("-", 1)
            a, b = int(a), int(b)
            out.append(list(range(min(a,b), max(a,b)+1)))
        else:
            out.append([int(x) for x in chunk.split(",") if x.strip()])
    return out

def _reduce_counts(device: str, local_correct: int, local_total: int, local_gen_tokens: int):
    """Combine per-rank counters into global accuracy, item count and token count.

    When distributed mode is enabled the three counters are summed across the
    process group with ``all_reduce``; otherwise the local values are used
    directly. ``device`` is handed to ``_pick_device`` to place the tensor
    used for the reduction.

    Returns:
        (accuracy, total_items, total_gen_tokens), where accuracy is
        ``correct / max(1, total_items)``.
    """
    if not _dist_is_enabled():
        return local_correct / max(1, local_total), local_total, local_gen_tokens

    import torch.distributed as dist

    counts = torch.tensor(
        [local_correct, local_total, local_gen_tokens],
        dtype=torch.long,
        device=_pick_device(device),
    )
    dist.all_reduce(counts, op=dist.ReduceOp.SUM)
    correct_sum, total_sum, token_sum = counts.tolist()
    return correct_sum / max(1, total_sum), total_sum, token_sum

def _auto_soft_json(stats_npz_path: str) -> Optional[str]:
    cand = [
        os.path.join(os.path.dirname(stats_npz_path), "soft_edges_top3.json"),
        os.path.join(os.path.dirname(stats_npz_path).replace("steer_stats", "steer_stats_last_baseline_soft"), "soft_edges_top3.json"),
        os.path.join(os.path.dirname(stats_npz_path).replace("steer_stats_last_baseline", "steer_stats_last_baseline_soft"), "soft_edges_top3.json"),
    ]
    for p in cand:
        if os.path.isfile(p):
            return p
    return None

def _build_vec_bank_from_soft(stats_npz: str, model_npz: str, soft_json: str, prefix: str):
    """
    Returns:
      vec_bank: List[np.ndarray]  (hidden-space vectors)
      probs:    np.ndarray [K] with sum=1
    """
    with open(soft_json, "r") as f:
        s = json.load(f)
    edges = s.get("edges", [])
    probs = np.asarray(s.get("weights", []), dtype=float)
    if len(edges) == 0 or probs.size == 0:
        raise ValueError(f"No edges/weights in {soft_json}")
    if len(edges) != probs.size:
        raise ValueError(f"Mismatch edges({len(edges)}) vs weights({probs.size}) in {soft_json}")

    z = np.load(stats_npz, allow_pickle=True)
    scaler, pca = load_preproc(model_npz)

    bank, keep = [], []
    for k, (i, j) in enumerate(edges):
        key = f"vec::{prefix}:{i},{j}"
        if key not in z.files:
            if _only_rank0():
                warnings.warn(f"[soft] Missing {key} in {stats_npz}; skipping")
            continue
        bank.append(invert_preproc_step(z[key], scaler, pca))
        keep.append(k)

    if not bank:
        raise ValueError(f"No edge vectors with prefix '{prefix}' in {stats_npz}")

    probs = probs[keep]
    ssum = float(probs.sum())
    probs = probs / ssum if ssum > 0 else np.ones(len(bank))/len(bank)
    return bank, probs

def _eval_batched_select_vec(
    runner: LLMRunner,
    items: List[Tuple[str,str]],
    metric: str,
    schedule: Optional[Dict[int,float]],
    base_seed: int,
    batch_size: int,
    step_aware: bool,
    mode: str,                            # "none", "single", "prob", "argmax"
    vec_hidden_single: Optional[np.ndarray] = None,
    vec_bank: Optional[List[np.ndarray]] = None,
    vec_probs: Optional[np.ndarray] = None,
    regex_answer: Optional[str] = None,
    regex_pred: Optional[str] = None,
    show_progress: bool = False,
):
    """Like evaluate_llm_accuracy_batched but chooses a vec per batch."""
    correct, rows = 0, []
    total_gen_tokens = 0

    it = range(0, len(items), batch_size)
    if show_progress and _only_rank0():
        it = tqdm(it, total=(len(items)+batch_size-1)//batch_size, desc=f"Evaluating[{mode}]", unit="batch")

    rng = np.random.default_rng(base_seed)

    for start in it:
        batch = items[start:start+batch_size]
        prompts = [p for (p, _) in batch]
        golds   = [g for (_, g) in batch]

        # choose vec for THIS batch
        if mode == "none":
            vec_hidden = None
        elif mode == "single":
            vec_hidden = vec_hidden_single
        elif mode in ("prob", "argmax"):
            assert vec_bank and len(vec_bank) > 0
            idx = int(np.argmax(vec_probs)) if mode == "argmax" else int(rng.choice(len(vec_bank), p=vec_probs))
            vec_hidden = vec_bank[idx]
        else:
            raise ValueError(f"unknown mode {mode}")

        gen = torch.Generator(device=_pick_device("auto"))
        gen.manual_seed(int(base_seed))

        preds, gens = runner.generate_batched(
            prompts,
            schedule=schedule if vec_hidden is not None else None,
            vec_hidden=vec_hidden,
            step_aware=step_aware,
            torch_generator=gen
        )

        for j, (pred, gold, gen_tokens) in enumerate(zip(preds, golds, gens)):
            ok = _grade(pred, gold, metric=metric, regex_answer=regex_answer, regex_pred=regex_pred)
            correct += int(ok); total_gen_tokens += int(gen_tokens)
            rows.append({"i": start + j, "prompt": prompts[j], "gold": gold, "pred": pred,
                         "ok": bool(ok), "gen_tokens": int(gen_tokens)})

    acc = correct / max(1, len(items))
    avg_gen_tokens = (total_gen_tokens / max(1, len(items)))
    return acc, rows, total_gen_tokens, avg_gen_tokens

# ----------------- main -----------------
def main():
    """Run the steering grid: one unsteered baseline, then every mode x alpha x schedule x layer-set.

    Steps, in order: parse CLI arguments, initialise distributed state and
    determinism, load and shard the evaluation items, build the ``LLMRunner``,
    evaluate the unsteered baseline once, prepare the steering vector(s) for
    the selected modes, run the sweep, and (on rank 0) write the JSON and
    optional CSV summaries. Per-rank detail files are written when
    ``--details_dir`` is given.

    Raises:
        SystemExit: if no valid steer mode is selected, if the requested
            ``vec::<strategy>`` key is missing, or if the soft JSON cannot be
            found for a ``soft_*`` mode.
    """
    ap = argparse.ArgumentParser()
    # inputs
    ap.add_argument("--stats_npz", required=True)
    ap.add_argument("--model_npz", required=True)

    # Strategy + modes
    ap.add_argument("--use_strategy", default="baseline_all_steps",
                    help="vec::<name> used for vec_base mode (default baseline_all_steps).")
    ap.add_argument("--steer_modes", default=None,
                    help="Comma-separated among: vec_base,soft_prob,soft_argmax. Default: all three.")
    ap.add_argument("--soft_json", default=None,
                    help="Path to soft_edges_top3.json (for soft_* modes). If omitted, auto-detect.")
    ap.add_argument("--edge_vec_prefix", default="edge_delta",
                    help="Prefix for edge vec keys in stats npz, e.g., vec::<prefix>:i,j .")

    # runtime/model
    ap.add_argument("--gen_model", default="bespokelabs/Bespoke-Stratos-7B")
    ap.add_argument("--tokenizer", default=None)
    ap.add_argument("--device", default="cuda:0")
    ap.add_argument("--dtype", default="float16", choices=["float16","bfloat16","float32"])
    ap.add_argument("--gen_temperature", type=float, default=0.6)
    ap.add_argument("--gen_top_p", type=float, default=0.95)
    ap.add_argument("--gen_top_k", type=int, default=None)
    ap.add_argument("--min_p", type=float, default=None)
    ap.add_argument("--gen_max_new_tokens", type=int, default=2000)

    # chat opts
    ap.add_argument("--use_nemotron_chat", action="store_true")
    ap.add_argument("--system_text", default="detailed thinking on")
    ap.add_argument("--final_boxed_hint", action="store_true")
    ap.add_argument("--use_qwen_chat", action="store_true")
    ap.add_argument("--qwen_enable_thinking", action="store_true")

    # dataset

    ap.add_argument("--eval_diamond", action="store_true")
    ap.add_argument("--diamond_split", choices=["train","test"], default="train")
    ap.add_argument("--diamond_n", type=int, default=100)
    ap.add_argument("--diamond_seed", type=int, default=None)
    ap.add_argument("--diamond_skip_first", type=int, default=0)

    ap.add_argument("--hf_dataset", default=None)
    ap.add_argument("--hf_config", default=None)
    ap.add_argument("--hf_split", default="test")
    ap.add_argument("--hf_prompt_key", default="question")
    ap.add_argument("--hf_answer_key", default="answer")
    ap.add_argument("--hf_seed", type=int, default=None)
    ap.add_argument("--hf_skip_first", type=int, default=0)
    ap.add_argument("--hf_filter_answer_types", default=None)
    ap.add_argument("--hf_filter_difficulties", default=None)

    ap.add_argument("--max_eval", type=int, default=100)
    ap.add_argument("--metric", choices=["em","numeric","regex"], default="numeric")
    ap.add_argument("--regex_answer", default=None)
    ap.add_argument("--regex_pred", default=None)
    ap.add_argument("--progress", action="store_true")

    # sweep
    ap.add_argument("--alphas", required=True, help="e.g. 0.05,0.1,0.2")
    ap.add_argument("--schedules", required=True, help="must be 'linear' (comma-separated OK)")
    ap.add_argument("--layer_sets", required=True, help="e.g. 24,26,28|22-31|21-26|31")

    # outputs
    ap.add_argument("--out_json", required=True)
    ap.add_argument("--out_csv", default=None)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--details_dir", default=None)

    # batching & step-aware
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--step_aware", action="store_true")

    args = ap.parse_args()

    # Expand modes (NO 'baseline' here)
    if args.steer_modes is None:
        modes = ["vec_base", "soft_prob", "soft_argmax"]
    else:
        modes = [m.strip() for m in args.steer_modes.split(",") if m.strip()]
    modes = [m for m in modes if m in {"vec_base","soft_prob","soft_argmax"}]
    if not modes:
        raise SystemExit("[error] No valid steer modes selected.")

    # DDP / determinism
    _dist_init()
    set_global_determinism(args.seed, strict=False)
    if args.details_dir:
        # Every rank writes its own detail files below, so every rank makes
        # sure the directory exists. Creating it on rank 0 only (with no
        # synchronisation before the first write) could fail on other ranks.
        os.makedirs(args.details_dir, exist_ok=True)

    # Load items
    if args.hf_dataset:
        # Only rank 0 loads the dataset; the list is then broadcast so that
        # all ranks shard exactly the same items.
        items_full = None
        if _only_rank0():
            items_full = _load_hf_dataset_items(
                ds_name=args.hf_dataset,
                ds_config=args.hf_config,
                split=args.hf_split,
                prompt_key=args.hf_prompt_key,
                answer_key=args.hf_answer_key,
                n=args.max_eval,
                seed=(args.hf_seed if args.hf_seed is not None else args.seed),
                skip_first=args.hf_skip_first,
                filter_answer_types=(args.hf_filter_answer_types.split(",") if args.hf_filter_answer_types else None),
                filter_difficulties=(args.hf_filter_difficulties.split(",") if args.hf_filter_difficulties else None),
            )
        if _dist_is_enabled():
            obj = [items_full]; torch.distributed.broadcast_object_list(obj, src=0); items_full = obj[0]
    elif args.eval_diamond:
        items_full = _load_gpqa_diamond_items(
            n=args.diamond_n, split=args.diamond_split,
            seed=(args.diamond_seed if args.diamond_seed is not None else args.seed),
            skip_first=args.diamond_skip_first
        )
        # NOTE(review): the original had `if args.metric == "regex": args.metric = "regex"`
        # here, which has no effect. --metric is used as supplied; confirm
        # whether a different override was intended for the diamond set.
    else:
        items_full = []

    if not items_full:
        if _only_rank0(): print("[grid] No evaluation items.")
        return

    rank, world = _dist_rank(), _dist_world()
    items = _shard_list(items_full, rank, world)

    # Runner
    runner = LLMRunner(
        model_name=args.gen_model, tokenizer_name=args.tokenizer,
        temperature=args.gen_temperature, top_p=args.gen_top_p,
        max_new_tokens=args.gen_max_new_tokens,
        device=args.device, dtype=args.dtype, top_k=args.gen_top_k,
        use_nemotron_chat=args.use_nemotron_chat,
        system_text=args.system_text,
        final_boxed_hint=args.final_boxed_hint,
        use_qwen_chat=args.use_qwen_chat,
        qwen_enable_thinking=args.qwen_enable_thinking,
        min_p=args.min_p,
    )

    # ---------- Global baseline (pure; only once) ----------
    base_acc, base_rows, base_total_gen_toks, _ = _eval_batched_select_vec(
        runner, items, metric=args.metric,
        schedule=None, base_seed=args.seed, batch_size=args.batch_size, step_aware=False,
        mode="none",
        regex_answer=args.regex_answer, regex_pred=args.regex_pred,
        show_progress=args.progress
    )
    if args.details_dir:
        with open(os.path.join(args.details_dir, f"baseline.rank{rank}.json"), "w") as f:
            json.dump({"rank": rank, "kind": "baseline", "n_local": len(items), "rows": base_rows}, f, indent=2)
    # Exact local count taken from the graded rows (no float round-trip
    # through the local accuracy).
    local_base_correct = sum(int(r["ok"]) for r in base_rows)
    g_base_acc, g_total_N, g_base_tok_total = _reduce_counts(
        args.device, local_base_correct, len(items), int(base_total_gen_toks)
    )
    if _only_rank0():
        print(
            f"[grid-baseline] model={args.gen_model} "
            f"| dataset={args.hf_dataset or ('diamond' if args.eval_diamond else 'custom')} "
            f"| seed={args.seed} → acc={g_base_acc:.4f}, N={g_total_N}, gen_toks_total={g_base_tok_total}"
        )

    # ---------- Prep sweep ----------
    alphas = _parse_alphas(args.alphas)
    schedules = _parse_schedules(args.schedules)
    layer_sets = _parse_layer_sets(args.layer_sets)
    results = []
    if _only_rank0() and any(s != "linear" for s in schedules):
        # The sweep below always builds a "linear" schedule regardless of the
        # name, so other entries only repeat the same configuration.
        warnings.warn(
            f"[grid] --schedules={args.schedules!r}: only 'linear' is implemented; "
            "every entry is run and reported as 'linear'."
        )

    # ---------- Prepare vecs for vec_base ----------
    vec_hidden_single = None
    if "vec_base" in modes:
        key = f"vec::{args.use_strategy}"
        # NOTE(security): allow_pickle=True can execute code from the file;
        # only use trusted stats files. The context manager closes the archive.
        with np.load(args.stats_npz, allow_pickle=True) as z:
            if key not in z.files:
                raise SystemExit(f"[error] vec::{args.use_strategy} not found in {args.stats_npz}")
            scaler, pca = load_preproc(args.model_npz)
            vec_hidden_single = invert_preproc_step(z[key], scaler, pca)

    # ---------- Prepare vec bank for soft_* ----------
    vec_bank, vec_probs = None, None
    if "soft_prob" in modes or "soft_argmax" in modes:
        soft_json = args.soft_json or _auto_soft_json(args.stats_npz)
        if not soft_json or not os.path.isfile(soft_json):
            raise SystemExit("[error] soft JSON not found. Pass --soft_json or place soft_edges_top3.json next to stats.")
        vec_bank, vec_probs = _build_vec_bank_from_soft(
            args.stats_npz, args.model_npz, soft_json, prefix=args.edge_vec_prefix
        )

    # ---------- Sweep modes (no 'baseline' here) ----------
    for mode in modes:
        sel_mode = {"vec_base": "single", "soft_prob": "prob", "soft_argmax": "argmax"}[mode]
        sweep_iter = [(a,s,ls) for a in alphas for s in schedules for ls in layer_sets]
        if _only_rank0() and args.progress:
            sweep_iter = tqdm(sweep_iter, desc=f"Grid[{mode}]", unit="cfg")

        for alpha, sched_name, layer_list in sweep_iter:
            # sched_name is intentionally unused: the schedule kind is fixed.
            schedule = make_schedule("linear", layer_list, alpha)

            steer_acc, steer_rows, steer_total_gen_toks, _ = _eval_batched_select_vec(
                runner, items, metric=args.metric,
                schedule=schedule,
                base_seed=args.seed, batch_size=args.batch_size, step_aware=args.step_aware,
                mode=sel_mode,
                vec_hidden_single=vec_hidden_single,
                vec_bank=vec_bank, vec_probs=vec_probs,
                regex_answer=args.regex_answer, regex_pred=args.regex_pred,
                show_progress=False
            )

            combo = _combo_id(alpha, "linear", layer_list)
            if args.details_dir:
                with open(os.path.join(args.details_dir, f"{combo}.{mode}.rank{rank}.json"), "w") as f:
                    json.dump({
                        "rank": rank, "kind": f"steered:{mode}",
                        "alpha": float(alpha), "schedule": "linear", "layers": layer_list,
                        "n_local": len(items), "rows": steer_rows
                    }, f, indent=2)

            local_steer_correct = sum(int(r["ok"]) for r in steer_rows)
            g_steer_acc, gN, g_steer_tok_total = _reduce_counts(
                args.device, local_steer_correct, len(items), int(steer_total_gen_toks)
            )

            if _only_rank0():
                res = {
                    "mode": mode,
                    "alpha": float(alpha),
                    "schedule": "linear",
                    "layers": layer_list,
                    "n": int(gN),
                    "baseline_acc": float(g_base_acc),
                    "steered_acc": float(g_steer_acc),
                    "delta_acc": float(g_steer_acc - g_base_acc),
                    "baseline_gen_tokens_total": int(g_base_tok_total),
                    "steered_gen_tokens_total": int(g_steer_tok_total),
                    "delta_gen_tokens_avg": (g_steer_tok_total - g_base_tok_total) / max(1, int(gN)),
                    "use_strategy": args.use_strategy if mode=="vec_base" else None,
                    "edge_vec_prefix": args.edge_vec_prefix if mode.startswith("soft_") else None,
                }
                results.append(res)
                print(
                    f"[grid-improve:{mode}] model={args.gen_model} "
                    f"| dataset={args.hf_dataset or ('diamond' if args.eval_diamond else 'custom')} "
                    f"| seed={args.seed} | linear | L={layer_list} | α={alpha} "
                    f"→ acc={g_steer_acc:.4f} (Δ={res['delta_acc']:+.4f}), Δtok_avg={res['delta_gen_tokens_avg']:+.2f}"
                )

    # ---------- Save once ----------
    if _only_rank0():
        # dirname() is "" for a bare filename and os.makedirs("") raises, so
        # only create the parent when the path actually has one.
        out_json_dir = os.path.dirname(args.out_json)
        if out_json_dir:
            os.makedirs(out_json_dir, exist_ok=True)
        with open(args.out_json, "w") as f:
            json.dump({
                "baseline": {"acc": g_base_acc, "n": g_total_N, "gen_tokens_total": g_base_tok_total},
                "runs": results
            }, f, indent=2)
        print(f"[grid] wrote {args.out_json}")

        if args.out_csv:
            out_csv_dir = os.path.dirname(args.out_csv)
            if out_csv_dir:
                os.makedirs(out_csv_dir, exist_ok=True)
            all_keys = set()
            for r in results: all_keys.update(r.keys())
            preferred = [
                "mode","alpha","schedule","layers","n",
                "baseline_acc","steered_acc","delta_acc",
                "baseline_gen_tokens_total","steered_gen_tokens_total","delta_gen_tokens_avg",
                "use_strategy","edge_vec_prefix"
            ]
            # Preferred columns first (in this order), then any others sorted.
            cols = [k for k in preferred if k in all_keys] + [k for k in sorted(all_keys) if k not in preferred]
            with open(args.out_csv, "w", newline="") as f:
                w = csv.DictWriter(f, fieldnames=cols, extrasaction="ignore")
                w.writeheader()
                for r in results:
                    row = {}
                    for k in cols:
                        v = r.get(k)
                        if k == "layers" and isinstance(v, (list, tuple)):
                            # Flatten the layer list into one CSV cell.
                            v = ",".join(map(str, v))
                        row[k] = v
                    w.writerow(row)
            print(f"[grid] wrote {args.out_csv}")

# others
if __name__ == "__main__":
    # Remember whether main() returned normally so cleanup can tell a clean
    # finish from a failure on this rank.
    main_ok = False
    try:
        main()
        main_ok = True
    finally:
        try:
            import torch.distributed as dist
            if _dist_is_enabled():
                # Synchronise only on the success path. If this rank failed,
                # waiting in barrier() for peers that may themselves be blocked
                # on a collective with this rank can hang and delay the
                # traceback, so go straight to teardown instead.
                if main_ok:
                    dist.barrier()
                dist.destroy_process_group()
        except Exception as e:
            # Best-effort cleanup: report the problem, but never replace the
            # outcome of main() with a shutdown error.
            print(f"[cleanup] process-group shutdown warning: {e}", flush=True)

# aime
# if __name__ == "__main__":
#     import os, datetime
#     import torch
#     import torch.distributed as dist

#     # --- init DDP only when launched via torchrun ---
#     ddp = ("RANK" in os.environ and "WORLD_SIZE" in os.environ)
#     if ddp and not dist.is_initialized():
#         # optional: pin the GPU for this rank
#         if torch.cuda.is_available() and "LOCAL_RANK" in os.environ:
#             torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

#         dist.init_process_group(
#             backend="nccl",                       # "nccl" on GPU, "gloo" if CPU-only
#             timeout=datetime.timedelta(hours=2),  # <- bump beyond your longest gen
#         )

#     try:
#         main()
#     finally:
#         # Clean shutdown; DO NOT barrier() here (it can hang if a rank failed)
#         try:
#             if ddp and dist.is_initialized():
#                 dist.destroy_process_group()
#         except Exception as e:
#             print(f"[cleanup] destroy_process_group warning: {e}", flush=True)
